{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "import numpy as np\n",
    "import xarray as xr\n",
    "\n",
    "climate_indices_home_path = \"/home/james/git/climate_indices\"\n",
    "if climate_indices_home_path not in sys.path:\n",
    "    sys.path.insert(climate_indices_home_path)\n",
    "from climate_indices import compute, indices\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "netcdf_precip = \"/data/datasets/nclimgrid/nclimgrid_lowres_prcp.nc\"\n",
    "netcdf_gamma = \"/data/datasets/nclimgrid/nclimgrid_lowres_gamma.nc\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Open the precipitation dataset as an xarray Dataset object."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "ds_prcp = xr.open_dataset(netcdf_precip)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['lat', 'lon', 'time']"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "list(ds_prcp.dims)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Get the precipitation data and transpose the array so that the time dimension is the inner-most (last) axis:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "da_prcp = ds_prcp[\"prcp\"].transpose(\"lat\", \"lon\", \"time\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# calibration period, scales, and periodicity used by the cells below\n",
    "calibration_year_initial, calibration_year_final = 1900, 2000\n",
    "scales = [1]  # , 2, 3, 6, 9, 12, 24]\n",
    "periodicity = compute.Periodicity.monthly\n",
    "\n",
    "# first year of the record and the (lat, lon, period) shape of the fitting arrays\n",
    "initial_year = int(da_prcp[\"time\"][0].dt.year)\n",
    "period_times = 12\n",
    "total_lats, total_lons = da_prcp.shape[:2]\n",
    "fitting_shape = (total_lats, total_lons, period_times)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define a function that can be used to compute the gamma fitting parameters for a particular month scale:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_gammas(\n",
    "    da_precip: xr.DataArray,\n",
    "    gamma_coords: dict,\n",
    "    scale: int,\n",
    "    calibration_year_initial,\n",
    "    calibration_year_final,\n",
    "    periodicity: compute.Periodicity,\n",
    ") -> (xr.DataArray, xr.DataArray):\n",
    "    \"\"\"\n",
    "    Fit gamma distribution parameters for every lat/lon grid cell at one scale.\n",
    "\n",
    "    :param da_precip: precipitation values, indexed here as [lat, lon] so that\n",
    "        each grid cell yields one series along the remaining (time) axis\n",
    "    :param gamma_coords: coordinates for the result arrays; the key order is used\n",
    "        as the dimension order and must match the (lats, lons, periods) shape\n",
    "    :param scale: scale passed to compute.scale_values before fitting\n",
    "    :param calibration_year_initial: passed to compute.gamma_parameters as\n",
    "        calibration_start_year\n",
    "    :param calibration_year_final: passed to compute.gamma_parameters as\n",
    "        calibration_end_year\n",
    "    :param periodicity: compute.Periodicity.monthly (12 fitted periods) or\n",
    "        compute.Periodicity.daily (366 fitted periods)\n",
    "    :return: the DataArrays \"alpha_<scale>\" and \"beta_<scale>\"; grid cells whose\n",
    "        input values are all NaN are left as NaN\n",
    "    :raises ValueError: if periodicity is neither monthly nor daily\n",
    "    \"\"\"\n",
    "    initial_year = int(da_precip[\"time\"][0].dt.year)\n",
    "    if periodicity == compute.Periodicity.monthly:\n",
    "        period_times = 12\n",
    "    elif periodicity == compute.Periodicity.daily:\n",
    "        period_times = 366\n",
    "    else:\n",
    "        # fail early instead of hitting an unbound period_times further down\n",
    "        raise ValueError(f\"Unsupported periodicity: {periodicity}\")\n",
    "    total_lats = da_precip.shape[0]\n",
    "    total_lons = da_precip.shape[1]\n",
    "    fitting_shape = (total_lats, total_lons, period_times)\n",
    "    # np.nan instead of the np.NaN alias, which NumPy 2.0 removed\n",
    "    alphas = np.full(shape=fitting_shape, fill_value=np.nan)\n",
    "    betas = np.full(shape=fitting_shape, fill_value=np.nan)\n",
    "\n",
    "    # loop over the grid cells and compute the gamma parameters for each\n",
    "    for lat_index in range(total_lats):\n",
    "        for lon_index in range(total_lons):\n",
    "            # get the precipitation values for the lat/lon grid cell\n",
    "            values = da_precip[lat_index, lon_index]\n",
    "\n",
    "            # skip over this grid cell if all NaN values\n",
    "            if (np.ma.is_masked(values) and values.mask.all()) or np.all(np.isnan(values)):\n",
    "                continue\n",
    "\n",
    "            # convolve to scale\n",
    "            scaled_values = compute.scale_values(\n",
    "                values,\n",
    "                scale=scale,\n",
    "                periodicity=periodicity,\n",
    "            )\n",
    "\n",
    "            # compute the fitting parameters on the scaled data\n",
    "            (\n",
    "                alphas[lat_index, lon_index],\n",
    "                betas[lat_index, lon_index],\n",
    "            ) = compute.gamma_parameters(\n",
    "                scaled_values,\n",
    "                data_start_year=initial_year,\n",
    "                calibration_start_year=calibration_year_initial,\n",
    "                calibration_end_year=calibration_year_final,\n",
    "                periodicity=periodicity,\n",
    "            )\n",
    "\n",
    "    # wrap the raw arrays as named, described DataArrays\n",
    "    alpha_attrs = {\n",
    "        \"description\": \"shape parameter of the gamma distribution (also referred to as the concentration) \"\n",
    "        + f\"computed from the {scale}-month scaled precipitation values\",\n",
    "    }\n",
    "    da_alpha = xr.DataArray(\n",
    "        data=alphas,\n",
    "        coords=gamma_coords,\n",
    "        dims=tuple(gamma_coords.keys()),\n",
    "        name=f\"alpha_{str(scale).zfill(2)}\",\n",
    "        attrs=alpha_attrs,\n",
    "    )\n",
    "    beta_attrs = {\n",
    "        \"description\": \"1 / scale of the distribution (also referred to as the rate) \"\n",
    "        + f\"computed from the {scale}-month scaled precipitation values\",\n",
    "    }\n",
    "    da_beta = xr.DataArray(\n",
    "        data=betas,\n",
    "        coords=gamma_coords,\n",
    "        dims=tuple(gamma_coords.keys()),\n",
    "        name=f\"beta_{str(scale).zfill(2)}\",\n",
    "        attrs=beta_attrs,\n",
    "    )\n",
    "\n",
    "    return da_alpha, da_beta"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define a function that can be used to compute the SPI for a particular month scale:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_spi_gamma(\n",
    "    da_precip: xr.DataArray,\n",
    "    da_alpha: xr.DataArray,\n",
    "    da_beta: xr.DataArray,\n",
    "    scale: int,\n",
    "    periodicity: compute.Periodicity,\n",
    ") -> xr.DataArray:\n",
    "    \"\"\"\n",
    "    Compute SPI (gamma distribution) for every lat/lon grid cell using\n",
    "    pre-computed fitting parameters.\n",
    "\n",
    "    The calibration period is read from the notebook-level variables\n",
    "    calibration_year_initial and calibration_year_final, which must be defined\n",
    "    before this function is called.\n",
    "\n",
    "    :param da_precip: precipitation values with dimensions (\"lat\", \"lon\", \"time\")\n",
    "    :param da_alpha: fitted alpha parameters, indexed as [lat, lon]\n",
    "    :param da_beta: fitted beta parameters, indexed as [lat, lon]\n",
    "    :param scale: scale passed to indices.spi\n",
    "    :param periodicity: periodicity of the input values, passed to indices.spi\n",
    "    :return: DataArray named \"spi_gamma_<scale>\" carrying the coordinates of\n",
    "        da_precip; grid cells whose input values are all NaN are left as NaN\n",
    "    \"\"\"\n",
    "    initial_year = int(da_precip[\"time\"][0].dt.year)\n",
    "    total_lats = da_precip.shape[0]\n",
    "    total_lons = da_precip.shape[1]\n",
    "    # np.nan instead of the np.NaN alias, which NumPy 2.0 removed\n",
    "    spi = np.full(shape=da_precip.shape, fill_value=np.nan)\n",
    "\n",
    "    for lat_index in range(total_lats):\n",
    "        for lon_index in range(total_lons):\n",
    "            # get the values for the lat/lon grid cell\n",
    "            values = da_precip[lat_index, lon_index]\n",
    "\n",
    "            # skip over this grid cell if all NaN values\n",
    "            if (np.ma.is_masked(values) and values.mask.all()) or np.all(np.isnan(values)):\n",
    "                continue\n",
    "\n",
    "            # fitting parameters previously computed for this grid cell\n",
    "            gamma_parameters = {\n",
    "                \"alphas\": da_alpha[lat_index, lon_index],\n",
    "                \"betas\": da_beta[lat_index, lon_index],\n",
    "            }\n",
    "\n",
    "            # compute the SPI\n",
    "            spi[lat_index, lon_index] = indices.spi(\n",
    "                values,\n",
    "                scale=scale,\n",
    "                distribution=indices.Distribution.gamma,\n",
    "                data_start_year=initial_year,\n",
    "                calibration_year_initial=calibration_year_initial,\n",
    "                calibration_year_final=calibration_year_final,\n",
    "                # honor the caller's periodicity (this was previously pinned to monthly)\n",
    "                periodicity=periodicity,\n",
    "                fitting_params=gamma_parameters,\n",
    "            )\n",
    "\n",
    "    # build a DataArray for this scale's SPI\n",
    "    da_spi = xr.DataArray(\n",
    "        data=spi,\n",
    "        coords=da_precip.coords,\n",
    "        dims=(\"lat\", \"lon\", \"time\"),\n",
    "        name=f\"spi_gamma_{str(scale).zfill(2)}\",\n",
    "    )\n",
    "    da_spi.attrs = {\n",
    "        \"description\": f\"SPI ({scale}-{periodicity} gamma) computed from monthly precipitation \"\n",
    "        + f\"data for the period {da_precip.time[0]} through {da_precip.time[-1]} using a \"\n",
    "        + f\"calibration period from {calibration_year_initial} through {calibration_year_final}\",\n",
    "        \"valid_min\": -3.09,\n",
    "        \"valid_max\": 3.09,\n",
    "        \"long_name\": f\"{scale}-{periodicity} SPI(gamma)\",\n",
    "        \"calibration_year_initial\": calibration_year_initial,\n",
    "        \"calibration_year_final\": calibration_year_final,\n",
    "    }\n",
    "\n",
    "    return da_spi"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Copy the attributes from the precipitation dataset that will be applicable to the corresponding gamma fitting parameters and SPI datasets:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# global attributes of the precipitation dataset that are carried over to the derived datasets\n",
    "attrs_to_copy = [\n",
    "    \"Conventions\", \"ncei_template_version\", \"naming_authority\",\n",
    "    \"standard_name_vocabulary\", \"institution\",\n",
    "    \"geospatial_lat_min\", \"geospatial_lat_max\",\n",
    "    \"geospatial_lon_min\", \"geospatial_lon_max\",\n",
    "    \"geospatial_lat_units\", \"geospatial_lon_units\",\n",
    "]\n",
    "# keep only the selected attributes, in the order the source dataset stores them\n",
    "global_attrs = {name: ds_prcp.attrs[name] for name in ds_prcp.attrs if name in attrs_to_copy}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compute the gamma fitting parameters for all scales and add these into a Dataset that we'll write to NetCDF:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "# number of fitted periods and the name of the period coordinate for the selected periodicity\n",
    "if periodicity == compute.Periodicity.monthly:\n",
    "    period_times, time_coord_name = 12, \"month\"\n",
    "elif periodicity == compute.Periodicity.daily:\n",
    "    period_times, time_coord_name = 366, \"day\"\n",
    "\n",
    "# coordinates shared by every alpha/beta variable\n",
    "gamma_coords = {\"lat\": ds_prcp.lat, \"lon\": ds_prcp.lon, time_coord_name: range(period_times)}\n",
    "ds_gamma = xr.Dataset(coords=gamma_coords, attrs=global_attrs)\n",
    "\n",
    "# fit the parameters once per scale and store them under scale-specific variable names\n",
    "for scale in scales:\n",
    "    var_name_alpha = f\"alpha_{str(scale).zfill(2)}\"\n",
    "    var_name_beta = f\"beta_{str(scale).zfill(2)}\"\n",
    "    da_alpha, da_beta = compute_gammas(\n",
    "        da_precip=da_prcp,\n",
    "        gamma_coords=gamma_coords,\n",
    "        scale=scale,\n",
    "        calibration_year_initial=calibration_year_initial,\n",
    "        calibration_year_final=calibration_year_final,\n",
    "        periodicity=periodicity,\n",
    "    )\n",
    "    ds_gamma[var_name_alpha] = da_alpha\n",
    "    ds_gamma[var_name_beta] = da_beta\n",
    "\n",
    "# persist the fitted parameters as NetCDF\n",
    "ds_gamma.to_netcdf(netcdf_gamma)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ds_gamma"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compute the SPI using the pre-computed gamma fitting parameters for all scales and add these into a SPI(gamma) Dataset that we'll write to NetCDF:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "# dataset that collects one SPI variable per scale\n",
    "ds_spi = xr.Dataset(\n",
    "    coords=ds_prcp.coords,\n",
    "    attrs=global_attrs,\n",
    ")\n",
    "for scale in scales:\n",
    "    # names of the fitting parameter variables stored in ds_gamma for this scale\n",
    "    var_name_alpha = f\"alpha_{str(scale).zfill(2)}\"\n",
    "    var_name_beta = f\"beta_{str(scale).zfill(2)}\"\n",
    "    da_spi = compute_spi_gamma(\n",
    "        da_prcp,\n",
    "        ds_gamma[var_name_alpha],\n",
    "        # pass the betas here, not a second copy of the alphas\n",
    "        ds_gamma[var_name_beta],\n",
    "        scale,\n",
    "        periodicity,\n",
    "    )\n",
    "    ds_spi[f\"spi_gamma_{str(scale).zfill(2)}\"] = da_spi\n",
    "\n",
    "# persist the SPI values as NetCDF\n",
    "netcdf_spi = \"/home/james/data/nclimgrid/nclimgrid_lowres_spi_gamma.nc\"\n",
    "ds_spi.to_netcdf(netcdf_spi)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ds_spi"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Plot a time step to validate that the SPI values look reasonable:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ds_spi[f\"spi_gamma_{str(scale).zfill(2)}\"].isel(time=500).plot()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define a function that computes SPI without using pre-computed fitting parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_spi_gamma_without_fittings(\n",
    "    da_precip: xr.DataArray,\n",
    "    scale: int,\n",
    "    periodicity: compute.Periodicity,\n",
    ") -> xr.DataArray:\n",
    "    \"\"\"\n",
    "    Compute SPI (gamma distribution) for every lat/lon grid cell without\n",
    "    supplying pre-computed fitting parameters to indices.spi.\n",
    "\n",
    "    The calibration period is read from the notebook-level variables\n",
    "    calibration_year_initial and calibration_year_final, which must be defined\n",
    "    before this function is called.\n",
    "\n",
    "    :param da_precip: precipitation values with dimensions (\"lat\", \"lon\", \"time\")\n",
    "    :param scale: scale passed to indices.spi\n",
    "    :param periodicity: periodicity of the input values, passed to indices.spi\n",
    "    :return: DataArray named \"spi_gamma_<scale>\" carrying the coordinates of\n",
    "        da_precip; grid cells whose input values are all NaN are left as NaN\n",
    "    \"\"\"\n",
    "    initial_year = int(da_precip[\"time\"][0].dt.year)\n",
    "    total_lats = da_precip.shape[0]\n",
    "    total_lons = da_precip.shape[1]\n",
    "    # np.nan instead of the np.NaN alias, which NumPy 2.0 removed\n",
    "    spi = np.full(shape=da_precip.shape, fill_value=np.nan)\n",
    "\n",
    "    for lat_index in range(total_lats):\n",
    "        for lon_index in range(total_lons):\n",
    "            # get the values for the lat/lon grid cell\n",
    "            values = da_precip[lat_index, lon_index]\n",
    "\n",
    "            # skip over this grid cell if all NaN values\n",
    "            if (np.ma.is_masked(values) and values.mask.all()) or np.all(np.isnan(values)):\n",
    "                continue\n",
    "\n",
    "            # compute the SPI\n",
    "            spi[lat_index, lon_index] = indices.spi(\n",
    "                values,\n",
    "                scale=scale,\n",
    "                distribution=indices.Distribution.gamma,\n",
    "                data_start_year=initial_year,\n",
    "                calibration_year_initial=calibration_year_initial,\n",
    "                calibration_year_final=calibration_year_final,\n",
    "                # honor the caller's periodicity (this was previously pinned to monthly)\n",
    "                periodicity=periodicity,\n",
    "            )\n",
    "\n",
    "    # build a DataArray for this scale's SPI\n",
    "    da_spi = xr.DataArray(\n",
    "        data=spi,\n",
    "        coords=da_precip.coords,\n",
    "        dims=(\"lat\", \"lon\", \"time\"),\n",
    "        name=f\"spi_gamma_{str(scale).zfill(2)}\",\n",
    "    )\n",
    "    da_spi.attrs = {\n",
    "        \"description\": f\"SPI ({scale}-{periodicity} gamma) computed from monthly precipitation \"\n",
    "        + f\"data for the period {da_precip.time[0]} through {da_precip.time[-1]} using a \"\n",
    "        + f\"calibration period from {calibration_year_initial} through {calibration_year_final}\",\n",
    "        \"valid_min\": -3.09,\n",
    "        \"valid_max\": 3.09,\n",
    "        \"long_name\": f\"{scale}-{periodicity} SPI(gamma)\",\n",
    "        \"calibration_year_initial\": calibration_year_initial,\n",
    "        \"calibration_year_final\": calibration_year_final,\n",
    "    }\n",
    "\n",
    "    return da_spi"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compute SPI without pre-computed fitting parameters to see if there's a significant time difference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "# dataset that collects one SPI variable per scale\n",
    "ds_spi_no_fittings = xr.Dataset(coords=ds_prcp.coords, attrs=global_attrs)\n",
    "for scale in scales:\n",
    "    da_spi = compute_spi_gamma_without_fittings(\n",
    "        da_precip=da_prcp,\n",
    "        scale=scale,\n",
    "        periodicity=periodicity,\n",
    "    )\n",
    "    ds_spi_no_fittings[f\"spi_gamma_{str(scale).zfill(2)}\"] = da_spi\n",
    "\n",
    "# persist the SPI values as NetCDF\n",
    "netcdf_spi = \"/home/james/data/nclimgrid/nclimgrid_lowres_spi_gamma_no_fittings.nc\"\n",
    "ds_spi_no_fittings.to_netcdf(netcdf_spi)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
